"""S3 bucket statistics."""
import datetime
import os

import boto3
from cki_lib import misc
from cki_lib import yaml
from cki_lib.cronjob import CronJob
from cki_lib.logger import get_logger
import prometheus_client

from cki.cki_tools import _utils

LOGGER = get_logger(__name__)
# Bucket configuration parsed from the S3_CONFIG environment variable,
# e.g. {'buckets': [{'name': ..., 'description': ...}], 'aws': {'prefix': ...}}.
# NOTE(review): presumably cki_lib.yaml.load returns a mapping (not None) for
# empty input, since run() calls S3_CONFIG.get() unconditionally — confirm.
S3_CONFIG = yaml.load(contents=os.environ.get('S3_CONFIG', ''))


class S3BucketMetrics(CronJob):
    """Calculate S3 bucket metrics."""

    schedule = '5 0 * * *'  # once a day

    # GetMetricData rejects requests with more than 500 metric queries.
    CLOUDWATCH_QUERY_LIMIT = 500

    metric_file_count = prometheus_client.Gauge(
        'cki_s3_bucket_file_count',
        'Number of files in the S3 bucket',
        ['name', 'description']
    )

    metric_bytes = prometheus_client.Gauge(
        'cki_s3_bucket_bytes',
        'Size of the bucket in bytes',
        ['name', 'description']
    )

    def s3_stats_from_bucket(self, name: str, description: str) -> None:
        """Get S3 bucket statistics from the objects in the bucket.

        Args:
            name: name of the environment variable containing the bucket
                spec; also used as the 'name' metric label.
            description: human-readable text for the 'description' label.
        """
        bucket_spec = _utils.parse_bucket_spec(os.environ[name])
        bucket = _utils.S3Bucket(bucket_spec).bucket
        size = 0
        count = 0
        # Enumerates every object below the configured prefix; this pages
        # through the bucket and can be slow for very large buckets.
        for bucket_object in bucket.objects.filter(Prefix=bucket_spec.prefix):
            size += bucket_object.size
            count += 1
        self.metric_file_count.labels(name, description).set(count)
        self.metric_bytes.labels(name, description).set(size)

    def _query_cloudwatch(self, cloudwatch, queries, start_time):
        """Return the GetMetricData results for one batch of metric queries."""
        response = cloudwatch.get_metric_data(
            MetricDataQueries=queries,
            StartTime=start_time,
            EndTime=start_time + datetime.timedelta(days=1),
        )
        return response['MetricDataResults']

    def s3_stats_from_cloudwatch(self, prefix=None):
        """Get S3 bucket statistics from CloudWatch.

        Args:
            prefix: if given, only consider buckets whose name starts with
                this prefix; otherwise consider all buckets of the account.
        """
        # The S3 storage metrics are only published once per day; query the
        # previous full UTC day to get (at most) one datapoint per metric.
        start_time = datetime.datetime.combine(
            misc.now_tz_utc(), datetime.time.min) - datetime.timedelta(days=1)
        session = boto3.Session()
        buckets = {b['Name'] for b in session.client('s3').list_buckets()['Buckets']
                   if not prefix or b['Name'].startswith(prefix)}
        # Map the generated query IDs back to the bucket/metric they describe;
        # GetMetricData IDs must start with a lowercase letter, hence .lower().
        metrics = {
            f'{metric[0].lower()}_{index}': {
                'metric': metric[0],
                'type': metric[1],
                'bucket': bucket,
            }
            for index, bucket in enumerate(buckets)
            for metric in (
                ('BucketSizeBytes', 'StandardStorage'),
                ('NumberOfObjects', 'AllStorageTypes'),
            )
        }
        queries = [{
            'Id': metric_id,
            'MetricStat': {
                'Metric': {
                    'Namespace': 'AWS/S3',
                    'MetricName': metric['metric'],
                    'Dimensions': [
                        {'Name': 'BucketName', 'Value': metric['bucket']},
                        {'Name': 'StorageType', 'Value': metric['type']},
                    ],
                },
                'Period': 86400,  # one datapoint covering the whole day
                'Stat': 'Average',
            },
        } for metric_id, metric in metrics.items()]
        cloudwatch = session.client('cloudwatch')
        results = []
        # With two metrics per bucket, a single GetMetricData call would fail
        # for more than 250 buckets; batch the queries to stay within limits.
        for index in range(0, len(queries), self.CLOUDWATCH_QUERY_LIMIT):
            results.extend(self._query_cloudwatch(
                cloudwatch,
                queries[index:index + self.CLOUDWATCH_QUERY_LIMIT],
                start_time))
        for result in results:
            metric = metrics[result['Id']]
            # An empty Values list means CloudWatch has no datapoint for this
            # bucket/day; fall back to 0 in that case.
            value = misc.get_nested_key(result, 'Values/0', 0)
            if metric['metric'] == 'BucketSizeBytes':
                self.metric_bytes.labels(metric['bucket'], metric['bucket']).set(value)
            else:
                self.metric_file_count.labels(metric['bucket'], metric['bucket']).set(value)

    def run(self, **_):
        """Update the bucket metrics for all configured buckets."""
        for bucket in S3_CONFIG.get('buckets', []):
            # A failure for a single bucket should not abort the others.
            with misc.only_log_exceptions():
                self.s3_stats_from_bucket(bucket['name'], bucket['description'])
        if aws_config := S3_CONFIG.get('aws'):
            # Be equally resilient against AWS/CloudWatch failures.
            with misc.only_log_exceptions():
                self.s3_stats_from_cloudwatch(aws_config.get('prefix'))
